import math

import torch
from torch import nn as nn
from torch.nn import functional as F

from .arch_util import flow_warp


class BasicModule(nn.Module):
    """Convolutional stack applied at a single SpyNet pyramid level.

    Five 7x7 convolutions (stride 1, padding 3, so the spatial size is kept)
    take an 8-channel input through 32, 64, 32 and 16 channels down to a
    2-channel output. A ReLU follows every convolution except the last one.
    """

    def __init__(self):
        super().__init__()

        # Channel count before the first conv and after each of the five convs.
        widths = (8, 32, 64, 32, 16, 2)

        stages = []
        for position, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position:
                # Activation sits between consecutive convolutions only, which
                # keeps the convs at Sequential indices 0, 2, 4, 6 and 8.
                stages.append(nn.ReLU(inplace=False))
            stages.append(
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=7,
                    stride=1,
                    padding=3,
                )
            )

        self.basic_module = nn.Sequential(*stages)

    def forward(self, tensor_input):
        """Run the stack on ``tensor_input`` (8 channels) and return 2 channels."""
        features = self.basic_module(tensor_input)
        return features


class SpyNet(nn.Module):
    """SpyNet architecture.

    Coarse-to-fine flow estimator: both frames are normalised, average-pooled
    into an image pyramid, and one ``BasicModule`` per pyramid level predicts
    a residual that is added to the upsampled flow of the coarser level.

    Args:
        load_path (str): path for pretrained SpyNet. Default: None.
    """

    def __init__(self, load_path=None):
        super(SpyNet, self).__init__()
        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
        if load_path:
            # NOTE(security): torch.load unpickles the file, so ``load_path``
            # must only ever point at a trusted checkpoint.
            # Loading happens before the buffers below are registered: at this
            # point the module holds only the ``basic_module`` weights, so the
            # (strict) load expects exactly those keys under "params".
            checkpoint = torch.load(
                load_path, map_location=lambda storage, loc: storage
            )
            self.load_state_dict(checkpoint["params"])

        # Per-channel normalisation constants, broadcastable over (b, 3, h, w).
        # NOTE(review): presumably ImageNet RGB statistics for inputs in
        # [0, 1] -- confirm against the training pipeline.
        self.register_buffer(
            "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )

    def preprocess(self, tensor_input):
        """Normalise ``tensor_input`` with the registered ``mean`` and ``std``."""
        return (tensor_input - self.mean) / self.std

    @staticmethod
    def _num_downsamples(height, width):
        """Return how many 2x poolings the image pyramid takes for this size.

        The depth is driven by the shorter side so that the coarsest level
        still measures at least 2 pixels in both directions; otherwise the
        initial flow (half the coarsest level) would have a zero-sized
        dimension and could not be upsampled.
        """
        shortest = min(height, width)
        if shortest > 2**5:
            return 5
        return int(math.log2(shortest) - 1)

    def process(self, ref, supp):
        """Estimate the flow for a frame pair at the resolution it is given in.

        Args:
            ref (Tensor): reference frames, indexed as 4-D (b, 3, h, w).
            supp (Tensor): supporting frames with the same shape as ``ref``.

        Returns:
            Tensor: flow of shape (b, 2, h, w).
        """
        height, width = ref.shape[-2], ref.shape[-1]

        # Pyramids are kept coarsest-first: index 0 is the smallest level.
        ref = [self.preprocess(ref)]
        supp = [self.preprocess(supp)]
        for _ in range(self._num_downsamples(height, width)):
            ref.insert(
                0,
                F.avg_pool2d(
                    input=ref[0], kernel_size=2, stride=2, count_include_pad=False
                ),
            )
            supp.insert(
                0,
                F.avg_pool2d(
                    input=supp[0], kernel_size=2, stride=2, count_include_pad=False
                ),
            )

        # Zero flow at half the coarsest resolution; the first loop iteration
        # upsamples it to the coarsest level.
        coarsest = ref[0]
        flow = coarsest.new_zeros(
            [coarsest.size(0), 2, coarsest.size(2) // 2, coarsest.size(3) // 2]
        )

        for level, (ref_level, supp_level) in enumerate(zip(ref, supp)):
            # Doubling the resolution also doubles displacements in pixels.
            upsampled_flow = (
                F.interpolate(
                    input=flow, scale_factor=2, mode="bilinear", align_corners=True
                )
                * 2.0
            )

            # A level with an odd size is one pixel larger than twice the
            # coarser one: replicate the last row / column to match it.
            if upsampled_flow.size(2) != ref_level.size(2):
                upsampled_flow = F.pad(
                    input=upsampled_flow, pad=[0, 0, 0, 1], mode="replicate"
                )
            if upsampled_flow.size(3) != ref_level.size(3):
                upsampled_flow = F.pad(
                    input=upsampled_flow, pad=[0, 1, 0, 0], mode="replicate"
                )

            # flow_warp is defined in arch_util; presumably it resamples
            # ``supp_level`` along the (b, h, w, 2) flow -- confirm there.
            warped_supp = flow_warp(
                supp_level,
                upsampled_flow.permute(0, 2, 3, 1),
                interp_mode="bilinear",
                padding_mode="border",
            )

            # 3 + 3 + 2 = 8 input channels; the network predicts a residual
            # on top of the upsampled flow.
            residual = self.basic_module[level](
                torch.cat([ref_level, warped_supp, upsampled_flow], 1)
            )
            flow = residual + upsampled_flow

        return flow

    def forward(self, ref, supp):
        """Estimate the flow between two batches of frames.

        Both inputs are resized so that height and width are multiples of 32
        (five halvings in the pyramid), the flow is estimated at that size,
        then resized back to the input resolution and rescaled accordingly.

        Args:
            ref (Tensor): reference frames of shape (b, 3, h, w).
            supp (Tensor): supporting frames, same shape as ``ref``.

        Returns:
            Tensor: flow of shape (b, 2, h, w).
        """
        assert ref.size() == supp.size()

        h, w = ref.size(2), ref.size(3)
        # Round each side *up* to the next multiple of 32.
        w_padded = math.ceil(w / 32.0) * 32
        h_padded = math.ceil(h / 32.0) * 32

        ref = F.interpolate(
            input=ref, size=(h_padded, w_padded), mode="bilinear", align_corners=False
        )
        supp = F.interpolate(
            input=supp, size=(h_padded, w_padded), mode="bilinear", align_corners=False
        )

        flow = F.interpolate(
            input=self.process(ref, supp),
            size=(h, w),
            mode="bilinear",
            align_corners=False,
        )

        # Displacements were measured on the resized grid; bring them back to
        # input pixels (channel 0 scales with width, channel 1 with height).
        flow[:, 0, :, :] *= float(w) / float(w_padded)
        flow[:, 1, :, :] *= float(h) / float(h_padded)

        return flow

    def get_flow(self, x):
        """Compute flows between neighbouring frames of a clip, both ways.

        Args:
            x (Tensor): frames of shape (b, n, c, h, w).

        Returns:
            tuple[Tensor, Tensor]: ``(flows_forward, flows_backward)``, each of
            shape (b, n - 1, 2, h, w). ``flows_backward[:, i]`` is computed
            with frame ``i`` as reference and frame ``i + 1`` as support;
            ``flows_forward[:, i]`` with the two roles swapped.
        """
        b, n, c, h, w = x.size()

        # Fold the time axis into the batch so each direction is one pass.
        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.forward(x_1, x_2).view(b, n - 1, 2, h, w)
        flows_forward = self.forward(x_2, x_1).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward
